import ast
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

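# Silence the HF tokenizers fork/parallelism warning when DataLoader workers are used.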
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class PretrainDataset(Dataset):
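    """Causal-LM pretraining dataset: wraps each text in BOS/EOS, truncates and
    right-pads to max_length, and yields next-token (X, Y) pairs with a loss
    mask that covers only the real (non-padding) tokens."""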
    def __init__(self, df, tokenizer, max_length=512):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Use the tokenizer's pad id when it defines one; otherwise fall back to 0.
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        sample = self.df.iloc[index]
        text = f'{self.tokenizer.bos_token}{str(sample["text"])}{self.tokenizer.eos_token}'
        input_id = self.tokenizer(text)['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Right-pad to max_length; padded positions are excluded via the loss mask.
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        loss_mask = [1] * text_len + [0] * padding_len
        input_id = np.array(input_id)
        # Shift by one token: X at position t is trained to predict y at position t.
        X = input_id[:-1].astype(np.int64)
        y = input_id[1:].astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(y), torch.from_numpy(loss_mask)

class SFTDataset(Dataset):
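    """Chat SFT dataset: renders history + (q, a) through the tokenizer's chat
    template, then builds a loss mask so that only the tokens of the final
    assistant answer contribute to the loss."""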
    def __init__(self, df, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len
        self.max_length = max_length

        # Use the tokenizer's pad id when it defines one; otherwise fall back to 0.
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Token ids of the assistant marker; used to locate where the answer begins.
        self.bos_id = self.tokenizer('<s>assistant')['input_ids']

    def __len__(self):
        return self.df.shape[0]

    def find_sublist_index(self, main_list, sub_list):
        # Return the start index of the LAST occurrence of sub_list within
        # main_list, or 0 if it never occurs. E.g. for main_list [1, 2, 3, 2, 3]
        # and sub_list [2, 3] this returns 3: the final assistant marker is the
        # one whose answer the loss should cover.
        last_index = 0
        for i in range(len(main_list) - len(sub_list) + 1):
            if main_list[i:i + len(sub_list)] == sub_list:
                last_index = i
        return last_index

    def safe_eval(self, s):
        # "history" is stored as a Python-literal string like "[['q', 'a'], ...]";
        # ast.literal_eval parses it without eval's arbitrary-code-execution risk.
        try:
            return ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError):
            return []

    def __getitem__(self, index):
        sample = self.df.iloc[index]
        history = self.safe_eval(sample["history"])
        q = str(sample["q"])
        a = str(sample["a"])
        # Rebuild the conversation: prior turns first, the current Q/A last.
        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append({"role": "user", "content": str(history_message[0])[:self.max_length // 2]})
            messages.append({"role": "assistant", "content": str(history_message[1])[:self.max_length // 2]})
        messages += [{"role": "user", "content": q}, {"role": "assistant", "content": a}]
        new_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        input_id = self.tokenizer(new_prompt)['input_ids'][:self.max_length]
        # Everything up to and including the last assistant marker is prompt;
        # only the final answer tokens carry loss.
        question_length = self.find_sublist_index(input_id, self.bos_id) + len(self.bos_id)
        padding_len = self.max_length - len(input_id)
        input_id = input_id + [self.padding] * padding_len
        mask_len = self.max_length - question_length - padding_len
        loss_mask = [0] * question_length + [1] * mask_len + [0] * padding_len
        input_id = np.array(input_id)
        # Shift by one token for next-token prediction.
        X = input_id[:-1].astype(np.int64)
        Y = input_id[1:].astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)

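def masked_lm_loss(logits, targets, loss_mask):
    # Hypothetical helper, not part of the original module: a minimal sketch of
    # how the (X, Y, loss_mask) triple is typically consumed. Cross-entropy is
    # computed per token, then averaged only over positions where the mask is 1,
    # so padding (and, for SFT, prompt tokens) contribute no gradient.
    # Assumed shapes: logits (B, T, V) from model(X); targets, loss_mask (B, T).
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction='none',
    )
    mask = loss_mask.reshape(-1).float()
    return (loss * mask).sum() / mask.sum().clamp(min=1)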

if __name__ == '__main__':

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('model')
    # Pretraining smoke test:
    # df = pd.read_csv("dataset/pretrain_data.csv")
    # df = df.sample(frac=0.1)  # keep a random 10% of the rows
    # train_ds = PretrainDataset(df, tokenizer, max_length=128)

    data_path = 'dataset/sft_data_single.csv'
    df = pd.read_csv(data_path)  # load the training data
    df = df.sample(frac=1)  # shuffle all rows (frac=1 returns the full set in random order)
    train_ds = SFTDataset(df, tokenizer, max_length=512)
    train_loader = DataLoader(
        train_ds,
        batch_size=64,
        pin_memory=False,
        drop_last=False,
        shuffle=False,
        num_workers=1,
        sampler=None,
    )  # build the data loader
    # Smoke-test the pipeline: each batch prints shapes (batch_size, max_length - 1),
    # since X/Y are shifted by one token (the last batch may be smaller, as drop_last=False).
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        print(X.shape, Y.shape, loss_mask.shape)